{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 在网页中应用实时MNIST识别\n",
    "> 小惊喜中带点美中不足吧\n",
    "\n",
    "- toc: true\n",
    "- badges: true\n",
    "- comments: true\n",
    "- hide: true\n",
    "- categories: [tensorflow, gradio]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 前言\n",
    "\n",
    "偶然发现一个包[Gradio](www.gradio.app)，可以快速构建基于机器学习应用的Web交互接口。这里尝试能否使用在Fastpages的Jupyter Notebook文章中。\n",
    "\n",
    "这里直接应用官网的demo之一——实时识别手写数字。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 上手\n",
    "\n",
    "1. 先训练一个MNIST分类器。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n",
      "11493376/11490434 [==============================] - 5s 0us/step\n"
     ]
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x27fa540d9d0>"
      ]
     },
     "metadata": {},
     "execution_count": 1
    }
   ],
   "source": [
    "import tensorflow as tf\n",
    "import gradio as gr\n",
    "\n",
    "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "x_train = x_train.astype('float32').reshape(-1, 28, 28, 1) / 255.\n",
    "x_test = x_test.astype('float32').reshape(-1, 28, 28, 1) / 255.\n",
    "\n",
    "# Lenet 5\n",
    "model = tf.keras.Sequential([\n",
    "    tf.keras.layers.Conv2D(filters=6, kernel_size=5, padding='same', activation='relu', input_shape=(28, 28, 1)),\n",
    "    tf.keras.layers.MaxPool2D(pool_size=2),\n",
    "    \n",
    "    tf.keras.layers.Conv2D(filters=16, kernel_size=5, activation='relu'),\n",
    "    tf.keras.layers.MaxPool2D(pool_size=2),\n",
    "    \n",
    "    tf.keras.layers.Conv2D(filters=120, kernel_size=5, activation='relu'),\n",
    "    tf.keras.layers.Flatten(),\n",
    "    tf.keras.layers.Dense(84, activation='relu'),\n",
    "    tf.keras.layers.Dense(10, activation='softmax'),\n",
    "])\n",
    "model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])\n",
    "model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=5, verbose=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. 构建Gradio界面"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Running locally at: http://127.0.0.1:7860/\n",
      "This share link will expire in 6 hours. If you need a permanent link, email support@gradio.app\n",
      "Running on External URL: https://44307.gradio.app\n",
      "Interface loading below...\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "<IPython.lib.display.IFrame at 0x27fa59b50d0>",
      "text/html": "\n        <iframe\n            width=\"1000\"\n            height=\"500\"\n            src=\"https://44307.gradio.app\"\n            frameborder=\"0\"\n            allowfullscreen\n        ></iframe>\n        "
     },
     "metadata": {}
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "(<Flask 'gradio.networking'>,\n",
       " 'http://127.0.0.1:7860/',\n",
       " 'https://44307.gradio.app')"
      ]
     },
     "metadata": {},
     "execution_count": 2
    }
   ],
   "source": [
    "# gradio interface need 3 parameters：predict function、input & output components\n",
    "\n",
    "# predict function\n",
    "def classify(image):\n",
    "    image = image.reshape(-1, 28, 28, 1)\n",
    "    prediction = model.predict(image).tolist()[0]\n",
    "    return {str(i): prediction[i] for i in range(10)}\n",
    "\n",
    "# input component\n",
    "sketchpad = gr.inputs.Sketchpad()\n",
    "\n",
    "# output components\n",
    "label = gr.outputs.Label(num_top_classes=3)\n",
    "\n",
    "# generate the UI\n",
    "# live=True: reload changes automatically\n",
    "# capture_sessiong=True: specifically for tensorflow 1.0\n",
    "interface = gr.Interface(classify, sketchpad, label, live=True, capture_session=True)\n",
    "interface.launch(share=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 结果\n",
    "\n",
    "功能都正常实现出来了，虽然识别准确率感人，但关我们主角Gradio什么事。\n",
    "\n",
    "但是Gradio文档展现的api确实有点少，比如本例中UI显示范围受到限制，想要调整UI大小却找不到相应的参数。总的来说Gradio更偏向傻瓜式，应用起来确实快捷简单。\n",
    "\n",
    "* 优点：\n",
    "\n",
    "    + 方便快捷\n",
    "\n",
    "* 缺点：\n",
    "\n",
    "    - api较少，可定义幅度小"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tf",
   "language": "python",
   "name": "tf"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}